#!/usr/bin/python3
# coding=utf-8
#
# Copyright (C) 2023-2024. Huawei Technologies Co., Ltd. All rights reserved.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# ===============================================================================

import numpy as np
import os

def mish(x):
    """Compute the Mish activation element-wise: f(x) = x * tanh(ln(1 + e^x)).

    Args:
        x: A scalar or NumPy array of numeric values.

    Returns:
        ``x * tanh(softplus(x))`` with the same shape as ``x``. For
        floating-point array input the dtype is preserved (float32 stays
        float32).
    """
    # softplus(x) = ln(1 + e^x). np.logaddexp(0, x) evaluates ln(e^0 + e^x)
    # without materializing e^x, so large positive x (above ~88.7 in float32,
    # ~709.8 in float64) no longer overflows to inf and triggers an overflow
    # RuntimeWarning the way np.log1p(np.exp(x)) does.
    softplus = np.logaddexp(0, x)
    return x * np.tanh(softplus)

def gen_golden_data_simple(*, shape=(8, 2048), low=-10.0, high=10.0, seed=None):
    """Generate a random float32 input and its Mish golden output.

    Writes the input to ``input/input_x.bin`` and the reference output to
    ``output/golden.bin`` via ``ndarray.tofile`` (raw float32 bytes, no
    header), creating both directories if they do not exist.

    Args:
        shape: Shape of the generated input array; defaults to (8, 2048).
        low: Lower bound (inclusive) of the uniform input distribution.
        high: Upper bound (exclusive) of the uniform input distribution.
            The default range [-10, 10) covers the region where Mish
            varies the most.
        seed: Optional integer seed. When None (the default) the global
            NumPy random state is used, exactly as before; pass an int to
            make the generated input and golden files reproducible.
    """
    # A seeded run uses its own Generator so the global RNG state is left
    # untouched; unseeded runs keep honoring any np.random.seed() set by
    # the caller.
    rng = np.random if seed is None else np.random.default_rng(seed)
    input_data = rng.uniform(low=low, high=high, size=shape).astype(np.float32)

    # Reference result, stored as float32 to match the input dtype.
    golden = mish(input_data).astype(np.float32)

    # Persist input and expected output as raw binary files.
    os.makedirs("input", exist_ok=True)
    input_data.tofile("input/input_x.bin")

    os.makedirs("output", exist_ok=True)
    golden.tofile("output/golden.bin")

def _main():
    """Script entry point: regenerate the Mish input and golden data files."""
    gen_golden_data_simple()


if __name__ == "__main__":
    _main()